from typing import Optional, Dict
from datasets import load_dataset
from dataclasses import dataclass

from utils.string_utils import truncate
from _datasets.utils import prepare_dataset


# Registry of supported datasets: config name -> keyword arguments forwarded to
# `load_dataset`. An optional "split" entry is the *default* split for that
# dataset; it is used only when the DatasetConfig does not request one.
_HF_DATASETS: Dict[str, Dict] = {
    "openai_summarize_scores": {
        "path": "openai/summarize_from_feedback",
        "name": "axis",
        "split": "test",
    },
    "openai_summarize_comparisons": {
        "path": "openai/summarize_from_feedback",
        "name": "comparisons",
        "split": "train",
    },
    "cnn_dailymail": {
        "path": "cnn_dailymail",
        "name": "3.0.0",
        "split": "train",
    },
    "billsum": {
        "path": "billsum",
        "split": "train",
    },
    "scientific_papers_robustness": {
        "path": "scientific_papers",
        "name": "pubmed",
        "split": "train",
        "trust_remote_code": True,
    },
    "biorxiv-clustering-p2p": {
        "path": "mteb/biorxiv-clustering-p2p",
        "split": "test",
    },
    "biorxiv-clustering-s2s": {
        # Previously pointed at the p2p repository (copy-paste slip).
        "path": "mteb/biorxiv-clustering-s2s",
        "split": "test",
    },
    "stackexchange-clustering-p2p": {
        "path": "mteb/stackexchange-clustering-p2p",
        "split": "test",
    },
    "stackexchange-clustering": {
        "path": "mteb/stackexchange-clustering",
        "split": "test",
    },
    "reddit-clustering-p2p": {
        "path": "mteb/reddit-clustering-p2p",
        "split": "test",
    },
    "reddit-clustering": {
        "path": "mteb/reddit-clustering",
        "split": "test",
    },
    "arxiv-clustering-p2p": {
        "path": "mteb/arxiv-clustering-p2p",
        "split": "test",
    },
    "arxiv-clustering-s2s": {
        "path": "mteb/arxiv-clustering-s2s",
        "split": "test",
    },
    "medrxiv-clustering-p2p": {
        "path": "mteb/medrxiv-clustering-p2p",
        "split": "test",
    },
    "medrxiv-clustering-s2s": {
        "path": "mteb/medrxiv-clustering-s2s",
        "split": "test",
    },
    "twentynewsgroups-clustering": {
        "path": "mteb/twentynewsgroups-clustering",
        "split": "test",
    },
    "scientific_papers_sensitivity": {
        "path": "scientific_papers",
        "name": "pubmed",
        "trust_remote_code": True,
    },
    "paul_graham": {
        "path": "sgoel9/paul_graham_essays",
    },
    "amazon_polarity": {
        "path": "amazon_polarity",
    },
    "arguana": {
        "path": "BeIR/arguana",
        "name": "corpus",
        "split": "corpus",
    },
    # Same source as "reddit-clustering-p2p", registered under a shorter name.
    "reddit": {
        "path": "mteb/reddit-clustering-p2p",
        "split": "test",
    },
}


@dataclass
class DatasetConfig:
    """Describes which registered dataset to load and how much of it to keep."""

    # Key into the dataset registry (e.g. "billsum", "cnn_dailymail").
    name: str
    # Keep only the first `num_examples` rows; None or 0 keeps every row.
    num_examples: Optional[int] = None
    # Split to load; when unset, the registry default (or "train") is used.
    split: Optional[str] = None

    def _load_kwargs(self) -> Dict:
        """Build the keyword arguments passed to ``load_dataset``.

        Returns:
            A fresh dict (the registry entry is never mutated) that always
            contains a "split" key.

        Raises:
            ValueError: If ``self.name`` is not a registered dataset.
        """
        if self.name not in _HF_DATASETS:
            raise ValueError(f"Unrecognized Dataset {self.name}")

        kwargs = dict(_HF_DATASETS[self.name])
        # Always remove the registry default first. Popping it only when
        # self.split is unset would leave "split" in kwargs and pass it twice.
        default_split = kwargs.pop("split", None)
        kwargs["split"] = self.split or default_split or "train"
        return kwargs

    def get_dataset(self, to_truncate: bool = False, max_length: int = 0):
        """Load the configured dataset and return it after preparation.

        The dataset is loaded with ``load_dataset``, converted with
        ``.to_pandas()``, optionally cut to the first ``num_examples`` rows,
        and passed through ``prepare_dataset``.

        Args:
            to_truncate: When true, every object-dtype column is passed
                through ``truncate`` with ``max_length``.
            max_length: Limit handed to ``truncate``; must be non-zero when
                ``to_truncate`` is true.

        Returns:
            The object returned by ``prepare_dataset`` (with truncated
            columns when requested).

        Raises:
            ValueError: If ``self.name`` is not a registered dataset.
            AssertionError: If ``to_truncate`` is true and ``max_length`` is 0.
        """
        kwargs = self._load_kwargs()

        # Validate before loading so a bad argument does not cost a download.
        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`; the exception type is kept for existing callers.
        if to_truncate and max_length == 0:
            raise AssertionError(f"Specify the max length {max_length}")

        ds = load_dataset(**kwargs).to_pandas()

        # Truthiness check on purpose: None and 0 both mean "no row limit".
        if self.num_examples:
            ds = ds[: self.num_examples]

        ds = prepare_dataset(self.name, ds)

        if to_truncate:
            for column in ds.columns:
                # Only object-dtype columns are truncated (presumably the
                # text columns; other dtypes are left untouched).
                if ds[column].dtype == object:
                    ds[column] = truncate(ds[column], max_length)

        return ds
